package drds.plus.parser.parser.dml.add;

import drds.plus.parser.abstract_syntax_tree.expression.Expression;
import drds.plus.parser.abstract_syntax_tree.expression.Pair;
import drds.plus.parser.abstract_syntax_tree.expression.primary.misc.Identifier;
import drds.plus.parser.abstract_syntax_tree.expression.primary.misc.RowValues;
import drds.plus.parser.abstract_syntax_tree.statement.InsertStatement;
import drds.plus.parser.abstract_syntax_tree.statement.Query;
import drds.plus.parser.lexer.Lexer;
import drds.plus.parser.lexer.Token;
import drds.plus.parser.parser.ExpressionParser;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;


public class InsertParser extends InsertReplaceParser {

    public InsertParser(Lexer lexer, ExpressionParser expressionParser) {
        super(lexer, expressionParser);
    }

    public InsertStatement insert() {
        match(Token.KW_INSERT);
        lexer.nextToken();
        match(Token.KW_INTO);
        lexer.nextToken();
        match(Token.IDENTIFIER);
        Identifier tableName = identifier();
        //
        List<Identifier> columnNameList;
        List<Expression> expressionList;//set
        List<RowValues> rowValuesList;
        Query query;
        //
        List<Pair<Identifier, Expression>> duplicateUpdate;
        if (lexer.token() == Token.KW_SET) {
            lexer.nextToken();
            columnNameList = new LinkedList<Identifier>();
            expressionList = new LinkedList<Expression>();
            do {
                Identifier identifier = identifier();
                match(Token.OP_EQUALS);
                lexer.nextToken();
                Expression expression = expressionParser.expression();
                //
                columnNameList.add(identifier);
                expressionList.add(expression);
            } while (lexer.token() == Token.PUNC_COMMA && lexer.nextToken() != Token.end_of_sql);//判断且获取下一个token
            rowValuesList = new ArrayList<RowValues>(1);
            rowValuesList.add(new RowValues(expressionList));
            duplicateUpdate = onDuplicateUpdate();
            return new InsertStatement(tableName, columnNameList, rowValuesList, duplicateUpdate);
        } else if (lexer.token() == Token.PUNC_LEFT_PAREN) {
            lexer.nextToken();
            columnNameList = buildIdentifierList();
            match(Token.PUNC_RIGHT_PAREN);
            lexer.nextToken();
            if (lexer.token() == Token.KW_VALUES) {
                lexer.nextToken();
                match(Token.PUNC_LEFT_PAREN);
                //lexer.nextToken();
                rowValuesList = rowDataList();
                duplicateUpdate = onDuplicateUpdate();
                return new InsertStatement(tableName, columnNameList, rowValuesList, duplicateUpdate);
            } else if (lexer.token() == Token.PUNC_LEFT_PAREN) {
                lexer.nextToken();
                query = selectStatement();//or trySelectStatement
                match(Token.PUNC_RIGHT_PAREN);
                lexer.nextToken();
                duplicateUpdate = onDuplicateUpdate();
                return new InsertStatement(tableName, columnNameList, query, duplicateUpdate);
            } else {
                throw error("不支持其他insert模式");
            }
        } else {
            throw error("unexpected token for insert: " + lexer.token());
        }

    }

    /**
     * @return null for not exist
     */
    private List<Pair<Identifier, Expression>> onDuplicateUpdate() {
        if (lexer.token() != Token.KW_ON) {
            return null;
        }
        lexer.nextToken();
        match("duplicate");
        lexer.nextToken();
        match(Token.KW_KEY);
        lexer.nextToken();
        match(Token.KW_UPDATE);
        lexer.nextToken();

        //
        Identifier identifier = identifier();
        match(Token.OP_EQUALS);
        lexer.nextToken();
        Expression expression = expressionParser.expression();
        //
        List<Pair<Identifier, Expression>> pairList = new LinkedList<Pair<Identifier, Expression>>();
        pairList.add(new Pair<Identifier, Expression>(identifier, expression));
        //
        if (lexer.token() == Token.PUNC_COMMA) {
            while (lexer.token() == Token.PUNC_COMMA) {
                lexer.nextToken();
                identifier = identifier();
                match(Token.OP_EQUALS);
                lexer.nextToken();
                expression = expressionParser.expression();
                pairList.add(new Pair<Identifier, Expression>(identifier, expression));
            }
        }
        pairList = new ArrayList<Pair<Identifier, Expression>>(pairList);
        return pairList;
    }
}
